// RUN: mlir-opt %s -test-linalg-push-reshape -split-input-file | FileCheck %s

// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1) -> (d1)>

// CHECK-LABEL: func @reshape
// CHECK-SAME: (%[[A:.*]]: tensor<?x16xf32>, %[[B:.*]]: tensor<16xf32>, %[[INIT:.*]]: tensor<?x112x16xf32>)
//      CHECK: %[[RI:.*]] = linalg.tensor_reshape %[[INIT]] [#[[$MAP0]], #[[$MAP1]]] : tensor<?x112x16xf32> into tensor<?x16xf32>
//      CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]], #[[$MAP2]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x16xf32>, tensor<16xf32>) outs(%[[RI]] : tensor<?x16xf32>)
//      CHECK: %[[RR:.*]] = linalg.tensor_reshape %[[R]] [#[[$MAP0]], #[[$MAP1]]] : tensor<?x16xf32> into tensor<?x112x16xf32>
//      CHECK: return %[[RR]] : tensor<?x112x16xf32>
// Input IR: %A is expanded from 2-D to 3-D (its leading dimension is split
// into d0 and d1) and then consumed by an element-wise generic op whose other
// input, %B, is indexed by d2 only.
func @reshape(%A: tensor<?x16xf32>, %B: tensor<16xf32>, %init: tensor<?x112x16xf32>) -> tensor<?x112x16xf32> {
  // Expanding reshape that feeds the generic op below.
  %expanded = linalg.tensor_reshape %A [
      affine_map<(d0, d1, d2) -> (d0, d1)>,
      affine_map<(d0, d1, d2) -> (d2)>
    ] : tensor<?x16xf32> into tensor<?x112x16xf32>
  // Element-wise subtraction; every loop is parallel.
  %result = linalg.generic {
      indexing_maps = [
        affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
        affine_map<(d0, d1, d2) -> (d2)>,
        affine_map<(d0, d1, d2) -> (d0, d1, d2)>
      ],
      iterator_types = ["parallel", "parallel", "parallel"]
    }
    ins(%expanded, %B : tensor<?x112x16xf32>, tensor<16xf32>)
    outs(%init : tensor<?x112x16xf32>) {
  ^bb0(%lhs: f32, %rhs: f32, %out: f32):
    %diff = subf %lhs, %rhs : f32
    linalg.yield %diff : f32
  } -> tensor<?x112x16xf32>
  return %result : tensor<?x112x16xf32>
}

// -----

// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1) -> (d1)>

// CHECK-LABEL: func @reshape_multiple
// CHECK-SAME: (%[[A:.*]]: tensor<12544x16xf32>, %[[B:.*]]: tensor<12544x16xf32>, %[[C:.*]]: tensor<16xf32>)
//      CHECK: %[[I:.*]] = linalg.init_tensor [112, 112, 16] : tensor<112x112x16xf32>
//      CHECK: %[[RI:.*]] = linalg.tensor_reshape %[[I]] [#[[$MAP0]], #[[$MAP1]]] : tensor<112x112x16xf32> into tensor<12544x16xf32>
//      CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP2]], #[[$MAP3]], #[[$MAP2]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
// CHECK-SAME: ins(%[[A]], %[[B]], %[[C]] : tensor<12544x16xf32>, tensor<12544x16xf32>, tensor<16xf32>) outs(%[[RI]] : tensor<12544x16xf32>)
//      CHECK: %[[RR:.*]] = linalg.tensor_reshape %[[R]] [#[[$MAP0]], #[[$MAP1]]] : tensor<12544x16xf32> into tensor<112x112x16xf32>
//      CHECK: return %[[RR]] : tensor<112x112x16xf32>
// Input IR: both %A and %B are expanded from 2-D to 3-D with the same
// reassociation and consumed by one element-wise generic op; the third input,
// %C, is indexed by d2 only.
func @reshape_multiple(%A: tensor<12544x16xf32>, %B: tensor<12544x16xf32>,
  %C: tensor<16xf32>) -> tensor<112x112x16xf32> {
  // Two expanding reshapes that use identical reassociation maps.
  %lhs_expanded = linalg.tensor_reshape %A [
      affine_map<(d0, d1, d2) -> (d0, d1)>,
      affine_map<(d0, d1, d2) -> (d2)>
    ] : tensor<12544x16xf32> into tensor<112x112x16xf32>
  %rhs_expanded = linalg.tensor_reshape %B [
      affine_map<(d0, d1, d2) -> (d0, d1)>,
      affine_map<(d0, d1, d2) -> (d2)>
    ] : tensor<12544x16xf32> into tensor<112x112x16xf32>
  // Output operand of the generic op.
  %init = linalg.init_tensor [112, 112, 16] : tensor<112x112x16xf32>
  // Computes (lhs - rhs) * scale element-wise; every loop is parallel.
  %result = linalg.generic {
      indexing_maps = [
        affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
        affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
        affine_map<(d0, d1, d2) -> (d2)>,
        affine_map<(d0, d1, d2) -> (d0, d1, d2)>
      ],
      iterator_types = ["parallel", "parallel", "parallel"]
    }
    ins(%lhs_expanded, %rhs_expanded, %C
        : tensor<112x112x16xf32>, tensor<112x112x16xf32>, tensor<16xf32>)
    outs(%init : tensor<112x112x16xf32>) {
  ^bb0(%lhs: f32, %rhs: f32, %scale: f32, %out: f32):
    %diff = subf %lhs, %rhs : f32
    %scaled = mulf %diff, %scale : f32
    linalg.yield %scaled : f32
  } -> tensor<112x112x16xf32>
  return %result : tensor<112x112x16xf32>
}

// -----

// Negative test: the second input is indexed by d1 alone, i.e. it is broadcast
// along d0 and d2, so the d0 and d1 dimensions cannot be merged and the
// reshape has to stay in front of the generic op.
// CHECK-LABEL: func @reshape_negative
// CHECK: linalg.tensor_reshape {{.*}} : tensor<12544x16xf32> into tensor<112x112x16xf32>
// CHECK: linalg.generic
// CHECK: } -> tensor<112x112x16xf32>
// Input IR: %A is expanded from 2-D to 3-D, while the second input of the
// generic op, %B, is indexed by d1 only.
func @reshape_negative(%A: tensor<12544x16xf32>, %B: tensor<112xf32>) -> tensor<112x112x16xf32> {
  // Expanding reshape that splits the leading dimension into d0 and d1.
  %expanded = linalg.tensor_reshape %A [
      affine_map<(d0, d1, d2) -> (d0, d1)>,
      affine_map<(d0, d1, d2) -> (d2)>
    ] : tensor<12544x16xf32> into tensor<112x112x16xf32>
  // Output operand of the generic op.
  %init = linalg.init_tensor [112, 112, 16] : tensor<112x112x16xf32>
  // Element-wise subtraction; every loop is parallel.
  %result = linalg.generic {
      indexing_maps = [
        affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
        affine_map<(d0, d1, d2) -> (d1)>,
        affine_map<(d0, d1, d2) -> (d0, d1, d2)>
      ],
      iterator_types = ["parallel", "parallel", "parallel"]
    }
    ins(%expanded, %B : tensor<112x112x16xf32>, tensor<112xf32>)
    outs(%init : tensor<112x112x16xf32>) {
  ^bb0(%lhs: f32, %rhs: f32, %out: f32):
    %diff = subf %lhs, %rhs : f32
    linalg.yield %diff : f32
  } -> tensor<112x112x16xf32>
  return %result : tensor<112x112x16xf32>
}
